#!/usr/bin/env python3
# _*_ coding:utf-8 _*_

import csv
import json
import re
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import List, Union, Dict

import requests
from loguru import logger

from lib.common import banner, get_base_path
from conf import settings
requests.packages.urllib3.disable_warnings()


class AutoExploitSwagger:
    """Discover the API operations published by a Swagger (springfox) UI and probe them.

    Workflow: ``get_all_urls`` reads ``<target>/swagger-resources`` to list the
    projects, loads each project's API document and collects its ``paths``;
    ``send_param`` then sends one request per operation using dummy parameter
    values. Every request actually sent is recorded as a row of the CSV file at
    ``settings.result_path``; PUT and DELETE operations are only logged so they
    can be tested by hand.

    All traffic goes through the configured HTTP proxy with TLS verification
    disabled. The result file is opened in ``__init__``; call ``close()`` (or use
    ``with AutoExploitSwagger(...) as scanner:``) when finished so it is flushed
    and released.
    """

    # A path-template placeholder such as ``{id}``; group 1 is the parameter name.
    _PLACEHOLDER = re.compile(r"\{([^{}]*)\}")

    # Column names of the result CSV; rows written by send_param follow this order.
    _CSV_HEADER = ("Projects", "Summary", "Path", "Method", "Consumes", "URL", "ParamsNum", "Data", "STATUS_CODE",
                   "Response")

    def __init__(self,
                 proxy_ip: str,
                 proxy_port: str,
                 timeout: Union[float, None] = None,
                 ):
        """Create a scanner and open the result CSV (truncating it).

        :param proxy_ip: host of the HTTP proxy all requests are routed through;
            a falsy value disables the proxy.
        :param proxy_port: port of that proxy.
        :param timeout: per-request timeout in seconds handed to ``requests``;
            ``None`` (the default) waits indefinitely.
        """
        self._project_path: str = ''
        # Base path of the most recently loaded project. Only a fallback for paths
        # whose origin was not recorded in _path_origins.
        self._base_path: str = ''
        self.all_paths: Dict = {}
        self._all_projects: Union[List, None] = None
        # path -> (API-document URL, base path) of the project that declared it, so that
        # in multi-project targets each path is requested against its own base path.
        self._path_origins: Dict[str, tuple] = {}
        self._proxy_ip: str = proxy_ip
        self._proxy_port: str = proxy_port
        self._timeout: Union[float, None] = timeout
        self._header_written: bool = False
        # send_param runs on a thread pool; serialize writes to the shared CSV file.
        self._write_lock = threading.Lock()
        # API information goes to the CSV; xray results are not included.
        self._result_file = open(settings.result_path, 'w', newline='', encoding='utf-8')
        self.writer = csv.writer(self._result_file)

    def close(self) -> None:
        """Flush and close the result CSV file. Safe to call more than once."""
        if not self._result_file.closed:
            self._result_file.close()

    def __enter__(self) -> "AutoExploitSwagger":
        return self

    def __exit__(self, exc_type, exc_value, exc_tb) -> None:
        self.close()

    @staticmethod
    def proxy_is_valid(proxy_ip: str, proxy_port: str):
        """Check whether the proxy is usable. Not implemented yet: always returns None."""
        # TODO: actually probe proxy_ip:proxy_port.
        return None

    def _get_proxy(self) -> Dict:
        """Return the ``proxies`` mapping for ``requests``.

        HTTP and HTTPS traffic both go through the plain-HTTP proxy at
        ``proxy_ip:proxy_port``. When no proxy IP is configured an empty mapping
        is returned instead of an unusable ``http://None:None`` proxy URL.
        """
        if not self._proxy_ip:
            return {}
        proxy = "http://{}:{}".format(self._proxy_ip, self._proxy_port)
        return {"http": proxy, "https": proxy}

    def _request(self, method: str, url: str, **kwargs):
        """Send one HTTP request through the proxy, without TLS verification."""
        return requests.request(method, url, proxies=self._get_proxy(), verify=False,
                                timeout=self._timeout, **kwargs)

    def _write_row(self, row) -> None:
        """Append one row to the result CSV under the write lock and flush it."""
        with self._write_lock:
            self.writer.writerow(row)
            self._result_file.flush()

    def _get_all_projects(self, target_url: str) -> Union[List, None]:
        """Fetch ``<target_url>/swagger-resources`` and cache the project list.

        :returns: the decoded list of project descriptors, or ``None`` when the
            endpoint answers with JSON that is not a list.
        :raises: any ``requests`` or JSON decoding error (handled by the caller).
        """
        resources_url: str = target_url + "/swagger-resources"
        res = self._request('get', resources_url)
        projects = json.loads(res.text)
        if not isinstance(projects, list):
            # e.g. an error object: there is nothing to iterate as projects.
            logger.error("[-] {} 返回的不是项目列表".format(resources_url))
            projects = None
        self._all_projects = projects
        return self._all_projects

    def _get_url(self, current_url: str) -> None:
        """Load one project's API document and merge its ``paths`` into ``all_paths``.

        Also records, for every path, which project URL and base path it came
        from. Any failure (network error, non-JSON body, missing ``paths``) is
        logged and the project is skipped.
        """
        try:
            base_path = get_base_path(current_url)
            self._base_path = base_path
            res = self._request('get', current_url)
            paths = json.loads(res.text)['paths']
            logger.info("[+] : 此项目下共有 %d 个api接口" % (len(paths)))
            self.all_paths.update(paths)
            for path in paths:
                self._path_origins[path] = (current_url, base_path)
        except Exception as e:
            logger.error(e)

    @staticmethod
    def get_param_type(parameter: Dict) -> tuple[str, bool]:
        """Choose a dummy value for a Swagger parameter.

        :returns: ``("true", True)`` for ``type: boolean`` parameters and
            ``("1", False)`` for everything else.
        """
        if parameter.get('type') == "boolean":
            return "true", True
        return "1", False

    @classmethod
    def _fill_path_params(cls, path: str, parameters: List) -> str:
        """Replace every ``{name}`` placeholder in *path* with a dummy value.

        The value is get_param_type() of the parameter with that name (``"1"``
        if none is declared). Path segments after a placeholder are kept.
        """
        by_name = {p.get('name'): p for p in parameters if isinstance(p, dict)}
        return cls._PLACEHOLDER.sub(lambda m: cls.get_param_type(by_name.get(m.group(1), {}))[0], path)

    @classmethod
    def _build_query_string(cls, parameters: List) -> str:
        """Build ``name=value&...`` with a dummy value for every named parameter.

        Entries without a ``name`` (e.g. unresolved ``$ref`` objects) are skipped.
        """
        return "&".join("%s=%s" % (param['name'], cls.get_param_type(param)[0])
                        for param in parameters if isinstance(param, dict) and 'name' in param)

    @staticmethod
    def _build_json_request(parameters: List, base_headers: Dict) -> tuple[Dict, Dict]:
        """Split *parameters* into request headers and a dummy JSON body.

        Header parameters are added to a copy of *base_headers* (*base_headers*
        itself is not modified) with ``'true'`` for booleans or
        ``settings.token`` otherwise. Every other named parameter goes into the
        body with ``'true'`` for booleans or ``'1234'`` otherwise.

        :returns: ``(headers, body)``.
        """
        headers = dict(base_headers)
        body = {}
        for param in parameters:
            if not isinstance(param, dict) or 'name' not in param:
                continue
            is_bool = param.get('type') == 'boolean'
            if param.get('in') == 'header':
                # Any string usually satisfies header checks; the configured token is the common case.
                headers[param['name']] = 'true' if is_bool else settings.token
            else:
                body[param['name']] = 'true' if is_bool else '1234'
        return headers, body

    def send_param(self, origin_dict: Dict, method: str, path: str) -> None:
        """Probe one API operation with dummy parameter values and record the result.

        :param origin_dict: the Swagger operation object (``paths[path][method]``).
        :param method: lower-case HTTP method key from the API document.
        :param path: the path template, e.g. ``/api/user/{id}``.

        Operations that declare ``consumes``: POST is sent with a dummy JSON body,
        or for templated paths with the placeholders filled in and no body; PUT
        is only logged. Other operations: GET is sent with the placeholders
        filled in, or with a dummy query string; DELETE is only logged. The CSV
        "URL" column holds the URL that was actually requested. Errors are
        logged and swallowed so one broken endpoint does not stop the scan.
        """
        try:
            if not isinstance(origin_dict, dict):
                # Path-level entries (e.g. a shared "parameters" list) are not operations.
                return
            project_path, base_path = self._path_origins.get(path, (self._project_path, self._base_path))
            summary = origin_dict.get('summary', base_path + path)  # "summary" is optional
            consumes = (origin_dict.get('consumes') or ['0'])[0]  # '0' means "not declared"
            parameters = origin_dict.get('parameters') or []
            param_num = str(len(parameters))

            if consumes != '0':  # the operation declares a request body type (usually JSON)
                if 'parameters' not in origin_dict:
                    logger.info(base_path + path + "----> 接口没有参数，接口参数个数为 %d" % 0)
                # Per-request copy: never mutate the shared settings.HEADERS from worker threads.
                headers, body = self._build_json_request(parameters, settings.HEADERS)
                json_string = json.dumps(body)
                if method == "post":
                    if '{' in path:  # e.g. post /api/mee/v2/building/select/{id}
                        url = base_path + self._fill_path_params(path, parameters)
                        res = self._request('post', url, headers=headers)
                    else:
                        url = base_path + path
                        res = self._request('post', url, data=json_string, headers=headers)
                    self._write_row([project_path, summary, path, method, consumes, url, param_num,
                                     json_string, res.status_code, res.text])
                elif method == "put":
                    logger.warning("[!] {} 存有put方法! 请手动测试".format(base_path + path))
            elif '{' in path:  # templated path such as /user/{id}, no request body
                if method == 'get':
                    url = base_path + self._fill_path_params(path, parameters)
                    res = self._request('get', url)
                    self._write_row([project_path, summary, path, method, consumes, url, param_num,
                                     "", res.status_code, res.text])
                elif method == 'delete':
                    logger.warning("[!] {} 存有delete方法! 请手动测试".format(base_path + path))
            else:  # plain path: pass every parameter in the query string
                if method == "get":
                    url = base_path + path + "?" + self._build_query_string(parameters)
                    res = self._request('get', url)
                    self._write_row([project_path, summary, path, method, consumes, url, param_num,
                                     "", res.status_code, res.text])
                elif method == "delete":
                    logger.warning("[!] {} 存有delete方法! 请手动测试".format(base_path + path))
        except Exception as e:
            logger.warning(e)

    def get_all_urls(self, target_url: str) -> Dict:
        """Collect the API paths of every project published under *target_url*.

        Enumerates projects via ``/swagger-resources`` and loads each project's
        API document. Paths accumulate in ``all_paths`` across calls on the
        same instance. The CSV header row is written once per instance, on the
        first target whose project list could be read.

        :returns: ``self.all_paths``, which is always a dict (empty if nothing was found).
        """
        # A trailing '/' would produce '//' in joined URLs, which some servers reject.
        target_url = target_url.strip().rstrip('/')
        try:
            self._get_all_projects(target_url)
        except Exception as e:
            # Never fall back to the project list of a previous target.
            self._all_projects = None
            logger.error(e)
        if self._all_projects is None:
            return self.all_paths
        logger.info("[+] 该URL里一共存在 %d 个项目" % len(self._all_projects))
        if not self._header_written:
            self._write_row(self._CSV_HEADER)
            self._header_written = True
        for project in self._all_projects:
            if not isinstance(project, dict) or 'url' not in project:
                logger.warning("[-] 跳过无 url 字段的项目: {}".format(project))
                continue
            current_url = target_url + project['url']  # API document URL
            logger.info("[+] : 开始测试URL: %s " % current_url)
            self._project_path = current_url
            self._get_url(current_url)
        return self.all_paths
# NOTE(review): the triple-quoted block below is an inert module-level string
# literal, not code. It holds a disabled threading.Thread-based ExploitThread
# sketch that nothing references. Concurrency is handled by exploit_threads via
# ThreadPoolExecutor, so this can likely be deleted.
'''
class ExploitThread(threading.Thread):
    def __init__(self):
        threading.Thread.__init__(self)
    def run(self):
        pass
'''

def exploit_threads(Swagger: 'AutoExploitSwagger', all_paths: Dict, worker_num=10) -> None:
    """Probe every (path, method) operation in *all_paths* on a thread pool.

    :param Swagger: scanner whose ``send_param(operation, method, path)`` is
        called once per operation.
    :param all_paths: Swagger ``paths`` mapping (path -> {method: operation}).
        ``None`` or an empty mapping means there is nothing to probe.
    :param worker_num: maximum number of worker threads.

    Returns once every submitted probe has finished (the executor is joined on
    leaving the ``with`` block).
    """
    if not all_paths:
        return
    with ThreadPoolExecutor(max_workers=worker_num) as pool:
        for path, operations in all_paths.items():
            if not isinstance(operations, dict):
                continue
            for method, operation in operations.items():
                pool.submit(Swagger.send_param, operation, method, path)

def read_urls(file):
    """Read target URLs from *file*, one per line.

    Surrounding whitespace is stripped, blank lines are skipped and duplicates
    are removed. A UTF-8 byte-order mark at the start of the file is ignored.

    :returns: a set of URLs; an empty set if the file cannot be opened or decoded.
    """
    try:
        # 'utf-8-sig' decodes plain UTF-8 identically but also drops a leading BOM,
        # which would otherwise be glued onto the first URL.
        with open(file, 'r', encoding='utf-8-sig') as f:
            return {url for url in (line.strip() for line in f) if url}
    except OSError:
        print('无法打开指定的文件!')
        return set()
    except (LookupError, UnicodeDecodeError):
        print('无法以 UTF-8 编码读取指定的文件!')
        return set()
def run(args) -> None:
    """Entry point: scan ``args.target_url``, or every URL listed in ``args.url_file``.

    ``args`` also supplies ``proxy_ip``/``proxy_port`` (the proxy all requests
    go through) and ``exploit_threads`` (worker count for probing).
    """
    logger.add(settings.logfile_path, format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}")
    banner()
    if args.target_url:
        targets = [args.target_url]
    elif args.url_file:
        targets = read_urls(args.url_file)
    else:
        return
    if not targets:
        return
    # One scanner for all targets: each AutoExploitSwagger opens the result CSV in
    # 'w' mode, so creating one per URL truncated the results of earlier URLs.
    auto_exploit_swagger = AutoExploitSwagger(args.proxy_ip, args.proxy_port)
    for url in targets:
        # Reset per-target state so each target only probes its own paths and a
        # failed project fetch cannot reuse the previous target's project list.
        auto_exploit_swagger.all_paths = {}
        auto_exploit_swagger._all_projects = None
        all_urls = auto_exploit_swagger.get_all_urls(url)  # dict of Swagger paths
        if all_urls:  # may be None/empty when no project could be loaded
            exploit_threads(auto_exploit_swagger, all_urls, args.exploit_threads)
